{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WpBVeU0XX8Uk"
   },
   "source": [
    "<h1>Chapter 12 - Fine-tuning Generation Models</h1>\n",
    "<i>Exploring a two-step approach for fine-tuning generative LLMs.</i>\n",
    "\n",
    "<a href=\"https://www.amazon.com/Hands-Large-Language-Models-Understanding/dp/1098150961\"><img src=\"https://img.shields.io/badge/Buy%20the%20Book!-grey?logo=amazon\"></a>\n",
    "<a href=\"https://www.oreilly.com/library/view/hands-on-large-language/9781098150952/\"><img src=\"https://img.shields.io/badge/O'Reilly-white.svg?logo=\"></a>\n",
    "<a href=\"https://github.com/HandsOnLLM/Hands-On-Large-Language-Models\"><img src=\"https://img.shields.io/badge/GitHub%20Repository-black?logo=github\"></a>\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/HandsOnLLM/Hands-On-Large-Language-Models/blob/main/chapter12/Chapter%2012%20-%20Fine-tuning%20Generation%20Models.ipynb)\n",
    "\n",
    "---\n",
    "\n",
    "This notebook is for Chapter 12 of the [Hands-On Large Language Models](https://www.amazon.com/Hands-Large-Language-Models-Understanding/dp/1098150961) book by [Jay Alammar](https://www.linkedin.com/in/jalammar) and [Maarten Grootendorst](https://www.linkedin.com/in/mgrootendorst/).\n",
    "\n",
    "---\n",
    "\n",
    "<a href=\"https://www.amazon.com/Hands-Large-Language-Models-Understanding/dp/1098150961\">\n",
    "<img src=\"https://raw.githubusercontent.com/HandsOnLLM/Hands-On-Large-Language-Models/main/images/book_cover.png\" width=\"350\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [OPTIONAL] - Installing Packages on <img src=\"https://colab.google/static/images/icons/colab.png\" width=100>\n",
    "\n",
    "If you are viewing this notebook on Google Colab (or any other cloud vendor), you need to **uncomment and run** the following code block to install the dependencies for this chapter:\n",
    "\n",
    "---\n",
    "\n",
    "💡 **NOTE**: You will need a GPU to run the examples in this notebook. In Google Colab, go to\n",
    "**Runtime > Change runtime type > Hardware accelerator > GPU > GPU type > T4**.\n",
    "\n",
    "---\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%capture\n",
    "# !pip install -q accelerate==0.31.0 peft==0.11.1 bitsandbytes==0.43.1 transformers==4.41.2 trl==0.9.4 sentencepiece==0.2.0 triton==3.1.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "v5luSSUAu_6d"
   },
   "source": [
    "# Supervised Fine-Tuning (SFT)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VPtcbw38_hVi"
   },
   "source": [
    "## Data Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 717,
     "referenced_widgets": [
      "0c0285e0913b46638191933995384e81",
      "dcf360a0f47a49e4a45077254414c5a3",
      "15c6aefae79544f9976aaf48c76724f0",
      "14dfbe9d7ea64d948115cbfc419f088c",
      "27f6c79febad4975bad1c6826f56bb3a",
      "e41a3acb92354b39883d0a593cfd134f",
      "5580337a0b914e39997d375ab566d320",
      "0cd0b4ceab9d437f853687736c038724",
      "502d9bd6b8214cd5947b69581d16f43f",
      "a17a283f2efa46a39b5b08a3c9a3b354",
      "5632acd62d3f45a38317099f9d22ebb6",
      "9122ddf6e4c441aa888d237ab95f3db5",
      "7ef9099272e74f3ba85ccff9bec40fdd",
      "4bef0d8577984f5984779e4e80f541df",
      "e6fbcce852524d81af28fc0b308c416a",
      "997a77c0b71b45e785d07824116698ee",
      "cf32f48abc8a4b17a9ba384a50a04990",
      "b342ae0e84234300a0ce4c665ca34f0e",
      "2be641c16747464fa4030e8adfdc3d46",
      "cf001da1deae4f1e97454ea275c0f41a",
      "daa634044dd240efa142477a69643b81",
      "7697e11d19b54b0a895b89db5f1cab1c",
      "9aa1d03488e44364889b9b87db48369e",
      "d8c2e6bc0e4e4dff85217b4d6b294a4c",
      "e353f2f8d5f844658f90c252f35ab205",
      "1c6db3d796204b0587475bab3d3accc4",
      "0e0b34b158f745f0bf80c2b5d9f8047a",
      "a6394bf93d9847019bc69951a1343314",
      "3b4c717b24f04a1481b5ef9316abfda0",
      "72f3a7db32824b94987384879da3547b",
      "8c6d3817f69f433184e0554e66334202",
      "f322425b5f9842c39b083986866ddad9",
      "10962d2bcb2f442881e6be4cc822fb23",
      "ec75adaddca64c2d95ddcba24a3611f6",
      "571991d082184f39b8e75fcac8b40da9",
      "00303d6943e94fa3a18e180e119825e3",
      "564fd5d03ddc4db898de20119b596488",
      "6339b2becd1243eaa4bcc1dec80c70ce",
      "aefa2e2b6a3f4f75bc3ec84138a58272",
      "737fbc0ba9684ec68eeed5e7b4958a42",
      "01c18865664e471982c22916a8cad324",
      "7c2bfb3fc0a74868baae251eb02c1b5e",
      "e967b96a23994543ac1c57e86a32a72f",
      "3a7583cae9aa45b4b51a3630b07f176a",
      "a7b18257b8d445eab00a2ad90dde1935",
      "e01822a863864a878d7fc8308b3524b9",
      "c146636d840f4f81bb7650855621e779",
      "fbed7ecb20094e9a8cf0d35713561cd2",
      "0bd9340be6ba4edb9eb571ec25bf659c",
      "3f3ca64e868c41e18d07d47485d55b89",
      "315d7aa759cb4fc7a5234fe71fcb2f20",
      "b799a5db12964557b723e296b8682683",
      "cf2f19683db94c549f819ac419ba6cd6",
      "7f5806d80d484a83b7fb2fa21941b4b4",
      "a109942f24ed47ad9f7be28c56cdb1f3",
      "4c27c2ebc6c04157910d64034bfce031",
      "87f9cb64141c4afabedbc608ec00843e",
      "f558f81464b64db0bdc7f212ae39f2f7",
      "d29ea76455374ea1b16400c0108389cf",
      "a327cfa279b941309db9d501da4ab103",
      "7995c22adaba4f73a94da0dc49e06253",
      "21c143577ecc462b97251df3fdeee5c6",
      "aaa7114d60674c0b9439a9bb15d878ae",
      "728edb90e05e48419efde0fbc4b1854c",
      "0640421a36dc490f8ba89f550a913148",
      "0e72b56b079d4e20910bdc9fe1ea9192",
      "ebfdda7e72694e2e96dc5b894d624959",
      "98caa4ec23d44d08bb5c72756556eae2",
      "0df385f54f2a4b289b8fb453bcc05c89",
      "775f135d2ae3404fa3aa31bc3e137205",
      "40da08ae319e406c988ec665a40c7017",
      "7f9c6d434da14849b3d1965aa62083bf",
      "3cf04c5fa94f4400b8b71657abd1105a",
      "717649a83f6a440d89e49d675f2b035c",
      "3b3362cef58443e19c88cba029895229",
      "813b2c775ef34c20b9b2471b40d189b4",
      "1b19e64c2b3b444b93a3f50adb3ecbef",
      "8c2509bb3a3244f8bee5cd03b3a63b01",
      "5b44bf29b5f24f79aef39b189403bc4d",
      "ceb0cb33a74e464683351014cb2777b9",
      "75255917e2b8415a91e06eaf9261a432",
      "0f4374fe6fb946959bad87e95dc641b5",
      "67386e8a10e8437dba09bcd15b9ca95e",
      "f3db9002e1db4f1095c8d256d88b77bd",
      "6d288ede18de4c7daa8c24bbe3e734cb",
      "e514be007a6e413695acb9ab7541e1e7",
      "67bcaf2825584c40b56c22000b1ef813",
      "dd6840b67b064e998dbcc5a66a1924b1",
      "593b914a05834c71a737480e32e080af",
      "20bbacc57f0944efbc53beb5162e3949",
      "5295e8e550324fef87dbfd6c7e10d960",
      "14e82108cbc9491dbed6b426ca3238d6",
      "1a7bd5986d3949318b8be938be47cd75",
      "371426080f734157ba6c90fcb5b06d32",
      "64b9301aa9774da28ee49bed3e3a0c8e",
      "a97d7b4039f54cfd853c03ad1dc62e7f",
      "d98b4f9684cb46b5be80154f76a1695d",
      "1b0cfe5b09d549fa98b3a4abcbd0be42",
      "25de49bee947408e970a951b78fdf818",
      "2cd7a2846de944cb8f3f046042b1c259",
      "3644a994605e42b6885b41b6ad9d4039",
      "52278a1aa6344c798d8cca868ae72aa0",
      "65c42da7b6334852ab88e93d366c6d76",
      "3d3f29b0df5a479083856f3962dc15a8",
      "4461deb4578b414681111aa897e4cd6c",
      "2275c7f8c6424a479cb241213c5cee86",
      "8a6cb467bb3e4bd985948490bfe5a131",
      "12740f7ada4742e2b6d80b5bd6b74607",
      "f20d3572b6eb40f88436c28025b81bf7",
      "ee9a1521598640eba1a6ddec51fc2684",
      "c0270f68d92347cc8a0be3c40166cbe6",
      "754aefe850fd41b58695eaa2b11c2642",
      "1674cc41b1dd460aab86339c96ea25dc",
      "d0e398214f1041ad92f693b8d27b6aa2",
      "f9698c8009694677a3001f7d26df3282",
      "cd68b532faaf4339b38ee45f05067376",
      "0146e7107826418cbcf12224d1d3eff9",
      "abd89fffda5a4caf80ea321921699eed",
      "f583bba54fd44c9994b837a1d5a1f4cc",
      "1af514f472e241deb1d3110dee38ff1d",
      "ffc03d8ef2a247c1a0761c7180c7b0eb",
      "7068dee2697343a88b82236f35dbd325",
      "b4df107982fa482bb4951a3e1ea5ae3e",
      "ba287f57e76d404dbee0d5384d5eae9f",
      "893c0cf2bf71484d9f9a476ed7b7060b",
      "b5fade77a4c741ccae482a3085a79254",
      "be12786eefed4d72babd91fbef0703f1",
      "4432f74694924808b4d77fd80d925121",
      "18a2b862f3364f2a97adf727b9997830",
      "bdaedbfa2fe143b881de9ac928f260f1",
      "0541ddcf2e6d415aa11ed9be9518409e",
      "78270dd565a347dfb00cbb1ee54b474d",
      "473de04017ba465982c4e9e6a15d7ad9",
      "d2b3525bdd3944db9b00bd26addb2046",
      "2ddc6ec3271d4417a3d9b0bf43a66f60",
      "dfb5602808f64aeb901fe73f9d0ae3dd",
      "007a8d7830d345f19f044c377de3034f",
      "dc07ee7aa1e24c09aca0c8837fea0e2b",
      "4ac1d624ec9a4a6eab38801b1e8605ea",
      "3e1df78695db4902b88be100a5880d64",
      "88235119e9774ebfb91c16632420501d",
      "c9af312b93ef4a54a1bba402966891ed",
      "f3e1492ca4fc46c8b3e209fc53edc55b",
      "3de01e557b784bd387bd6985b05c58d2",
      "531aa8d491394608ad78374c96e1f182",
      "266bf17916d14b309031ef704ebf8b16",
      "034fe1a3c54c4aa5afab13b26946c390",
      "8fb7ab95392f41aba2884c67b4970b93",
      "5600fc1f67c64444b22f26b040e53bad",
      "603774211a8f44409f626aec41739f8b",
      "000129269bda41fc95af5d86e32037e1",
      "0d7ed05129774701a0e1fd5d9dee938e",
      "4fa7fbd7f10c43239619efe8c94891c2",
      "a77b8d97a3ca47fdabf6d3f154440332",
      "f78630d608e646a985d1af14bbe9108d",
      "89f8d133c512401caa43f4adc82923db",
      "38ed39fcb5854a5693cea58d2e7df3da",
      "5c22e529a1d345f58b2a35b932434ace",
      "ddd79e89dbd54530905cf3c11fedd4e6",
      "8db41e735e454cf187047c2453e6097d",
      "381293f50a7c41518b46e5dba81162b2",
      "658c9da8d4994da9bd6fc900959f79aa",
      "8d3012ca9c4a4b75bce993cf668fd062",
      "475082257c81404c8b1fd270f6dbf895",
      "ea68d2393dae4db6b7947d9589176c66",
      "f6768e31065f48598a625ed8831d69ac",
      "f782508c19df4dfba8a1e3bf14c50131",
      "9ea8370dddd1469daf46c6e9c9784c59",
      "ed09b731dfae40feb977fd982e92d237",
      "550634ce52b24efb9b29231ddd2bcdd8",
      "9d3cc2f714be4b8e9605f92682d7ac73",
      "a00545a641274994955018848dba8622",
      "c2928952588648fa824632bb43e75712",
      "6ae43605a1fb4dfaa38214c118201269",
      "57318597fd344d0a8bd0cef8855efe53",
      "5369f8381c0b43afa2bcb7193db21a39",
      "5994759ae74240b396b00df3d2e87caf",
      "6c2124867ce14599917c4f10a211911b",
      "4a575c47233e4722bc9252f8d650d014",
      "48869b2592504cc7a0d4dc5a8eb009ae",
      "3fa9beddb8304d3e8610923802ff6e25",
      "588bd78207204c74a4b9f45cd660f452",
      "af85f9be14e040ad9a38390d6629646d",
      "7b54eabb0cbe4966a0fdab4b3aab652d",
      "c93c5cb2419847c5a305c5376e3bcac1",
      "fac0d8b63fb44fd195c1b80d91037055",
      "504065bf78444cedaab7c54a4d6a888f",
      "79cb1e5865cb4cf1a3ec6592b2b70a87",
      "f659f53ffde94033b805b8166fc7b861",
      "61a6e7784a0d43d086ee4ac57fb138d4",
      "abd923ac9c324e76a8305bb624934ecc",
      "c5a657ab85c14fc9a8fe0d53e51edac9",
      "d285bec3cce74bfab2e54d764192cf94",
      "faf7e7fea65246a788edb30cc32f430e",
      "3301916017cc47ccacecdec36b37833d",
      "da88cae9dab740548ca0593208ff8e41",
      "05db7dee35e8409791c0b12711e92d4f",
      "a23266458ea94362bf5f3696429327ce"
     ]
    },
    "executionInfo": {
     "elapsed": 31435,
     "status": "ok",
     "timestamp": 1719390601273,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "SqeZchJiOXdd",
    "outputId": "516dbb28-1771-4e35-bc2a-3317a56960d8"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0c0285e0913b46638191933995384e81",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9122ddf6e4c441aa888d237ab95f3db5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9aa1d03488e44364889b9b87db48369e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ec75adaddca64c2d95ddcba24a3611f6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a7b18257b8d445eab00a2ad90dde1935",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading readme:   0%|          | 0.00/4.44k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4c27c2ebc6c04157910d64034bfce031",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/244M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ebfdda7e72694e2e96dc5b894d624959",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/244M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8c2509bb3a3244f8bee5cd03b3a63b01",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/244M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "593b914a05834c71a737480e32e080af",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/81.2M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2cd7a2846de944cb8f3f046042b1c259",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/244M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c0270f68d92347cc8a0be3c40166cbe6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/243M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7068dee2697343a88b82236f35dbd325",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/243M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "473de04017ba465982c4e9e6a15d7ad9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/80.4M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3de01e557b784bd387bd6985b05c58d2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train_sft split:   0%|          | 0/207865 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f78630d608e646a985d1af14bbe9108d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating test_sft split:   0%|          | 0/23110 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f6768e31065f48598a625ed8831d69ac",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train_gen split:   0%|          | 0/256032 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5994759ae74240b396b00df3d2e87caf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating test_gen split:   0%|          | 0/28304 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "79cb1e5865cb4cf1a3ec6592b2b70a87",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/3000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "from datasets import load_dataset\n",
    "\n",
    "\n",
    "# Load a tokenizer to use its chat template\n",
    "template_tokenizer = AutoTokenizer.from_pretrained(\"TinyLlama/TinyLlama-1.1B-Chat-v1.0\")\n",
    "\n",
    "def format_prompt(example):\n",
    "    \"\"\"Format the prompt using the <|user|> template that TinyLlama uses\"\"\"\n",
    "\n",
    "    # Render the list of chat messages into a single string\n",
    "    chat = example[\"messages\"]\n",
    "    prompt = template_tokenizer.apply_chat_template(chat, tokenize=False)\n",
    "\n",
    "    return {\"text\": prompt}\n",
    "\n",
    "# Load a shuffled 3,000-example sample of UltraChat and format it with TinyLlama's chat template\n",
    "dataset = (\n",
    "    load_dataset(\"HuggingFaceH4/ultrachat_200k\", split=\"test_sft\")\n",
    "      .shuffle(seed=42)\n",
    "      .select(range(3_000))\n",
    ")\n",
    "dataset = dataset.map(format_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1719390601273,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "dtl2xZptgyDf",
    "outputId": "304c49f2-16c8-47ad-f8fb-4d975012e6d3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|user|>\n",
      "Given the text: Knock, knock. Who’s there? Hike.\n",
      "Can you continue the joke based on the given text material \"Knock, knock. Who’s there? Hike\"?</s>\n",
      "<|assistant|>\n",
      "Sure! Knock, knock. Who's there? Hike. Hike who? Hike up your pants, it's cold outside!</s>\n",
      "<|user|>\n",
      "Can you tell me another knock-knock joke based on the same text material \"Knock, knock. Who's there? Hike\"?</s>\n",
      "<|assistant|>\n",
      "Of course! Knock, knock. Who's there? Hike. Hike who? Hike your way over here and let's go for a walk!</s>\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Example of formatted prompt\n",
    "print(dataset[\"text\"][2576])"
   ]
  },
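  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The chat template applied above wraps every turn in role tags. Judging from the printed example, each turn is rendered as `<|role|>`, a newline, the message content, the `</s>` end-of-sequence token, and another newline. The hypothetical `format_chat_manually` helper below rebuilds that layout by hand to make the structure explicit. It is only a sketch inferred from the output above, not the tokenizer's actual template (for example, it ignores `add_generation_prompt`), so keep using `apply_chat_template` for real data. If the approximation holds, both printed strings should match."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_chat_manually(messages, eos_token='</s>'):\n",
    "    \"\"\"Hypothetical helper that approximates the chat layout printed above.\"\"\"\n",
    "    text = ''\n",
    "    for message in messages:\n",
    "        # Each turn: role tag, newline, content, end-of-sequence token, newline\n",
    "        text += '<|' + message['role'] + '|>\\n' + message['content'] + eos_token + '\\n'\n",
    "    return text\n",
    "\n",
    "example_chat = [\n",
    "    {'role': 'user', 'content': 'Tell me a knock-knock joke.'},\n",
    "    {'role': 'assistant', 'content': 'Knock, knock. Who is there?'},\n",
    "]\n",
    "print(format_chat_manually(example_chat))\n",
    "\n",
    "# Compare with the tokenizer's own rendering of the same messages\n",
    "print(template_tokenizer.apply_chat_template(example_chat, tokenize=False))"
   ]
  },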
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CyuLZGizDqUB"
   },
   "source": [
    "## Models - Quantization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 241,
     "referenced_widgets": [
      "1b4de2592d454a7ba7560ea6849dc6ba",
      "16f636aa216e44a1b9ebb66961b45361",
      "d55c8156644740c48012383eb903b7c5",
      "b40ba0141f194f36b95882644fb2a41e",
      "e8499c18c273492d8fb55fdd50f8f9d2",
      "db0c43c383d3479989e3a307a8c635fd",
      "2e5ccfc7cf4e45cbbdf74add6bf2fdf2",
      "b66b520fbdf84c78adc2f16910ed7a2a",
      "bb4131ce7126465ba7f0107f51d4bd5a",
      "865041ec26e3499cb66da96354bc9a7a",
      "32a2bd764cae40d59467ea71dccea11d",
      "b226cae1a8fc4b8384e0f5fc67d94d89",
      "220334bb55104a01b4e112a05a79db77",
      "a956b9fa522e4d0fafc8fb415ac109ec",
      "a6ae124d59814210913da1e50e82a90b",
      "c03dd90be661408c8d94238ce901081d",
      "23399b989c4449e0ae28c9932b104ba4",
      "16a053e1fc61408a97c650b58fd56913",
      "373a8dedb2b6498c8831f01912e87f10",
      "ec9bf18f5baa45d6a90c1fd2be0e941f",
      "88e2d8f90d814e61b2b8993f12f24d0c",
      "38e7e69833ec4298872afb91e3ba76a0",
      "da07fb8f4a204a4aa9125ddb89526992",
      "3071cd29353c4137a104b974efb903fb",
      "e650428153f441c19a9ce60c7f36b726",
      "6f6d884827dd403280cc2bd44febf8f7",
      "505f2e0ae82f4d5190e7f0a61c693c7d",
      "3ef9ff15982048bcbd9667af97be7a91",
      "f7e0a4c846264a3ebf8bec09a6520674",
      "ac64efb662ca4a2593609ee2797cb188",
      "2211c8ccb3eb4d0d8e5d7ce14e09c843",
      "dea87b44934648cb86d95f4f081005b1",
      "cf4f55ef9e2f450b891149a5cf3d190d",
      "4f88583ce59b467cb9791d09f1d57afd",
      "825cf142c3a842ccb3f7001138a11930",
      "12959d7e56374d8ca49ff18544b4db7e",
      "f301f0072ff649f3a75dea4ed687c294",
      "37273f57135141deb8cb53d858669834",
      "3c0da0f585f64a489a28fb6afa1e6f5f",
      "5b4f672bcb7546eca0e2dbce43900798",
      "e918450ee4b74d03bab029dd7230728e",
      "57040592ffb3433fbdfe90ea1a52d1ab",
      "0ab39d8c544a43bcade2ab94f2a61a0a",
      "fedef8865ad740d0a5cf4167b0067bf6",
      "7ab5f2a459764205ae4263b22f64a7aa",
      "f6092670ea49450e93fbf62eab75d996",
      "6c042632a5b94c5bb9b387ad17920908",
      "7ddca642a1bb4605ad508b1d56c3a61f",
      "1d44d180cee1440ca280e7004aa9bbda",
      "33fd53bddde646cf99019f98e04cb1d5",
      "744008ceeebd4809969b054d3a09d6c5",
      "1deb1756e91941af91bc2f404b5c52c9",
      "9e2a150555434cf39df5da21932a57bf",
      "397231517f4e46fbaef4408d36ecb1c8",
      "1c5d2c4d4d74406fbd4a3fb5b528eed9",
      "a641c51801964d749c11be9ecdaf8749",
      "4375456a00da42c4af307d3b0680ec02",
      "51778001c9214474870749b4af4e2d23",
      "63ae62bab5394d848265983975840f53",
      "e1337ad30ae74bf28c4f88ccecb24c06",
      "a80691a3b01a4c409c49833aa4e2c4ea",
      "f51f84cc3c9e41ea82720a9ed41cfbe4",
      "ec18963535394ba2acef196aff2603b9",
      "9666d1392eb5438b84b76dcfd3222992",
      "cd70eac32fea43a3bfb08cf52d74ae4b",
      "4f4660a0cec2426e8d7ff0b673de28be",
      "5f59b68201d94411b978417aa902e747",
      "e45f18c8c9b447eda9c64a47df8d4881",
      "179ee7d876784842b278c92ce0ac4f7d",
      "e664c66f54e9465c8bd404b2ae5aaf4e",
      "a55c2efc970b481a8fb1ad771c6fc1ee",
      "686140ba3e5b428088de8d94b6e3c707",
      "601b905edc4d410184ee34f83d993971",
      "b1738e5723ca4a4ca82f9038897ce547",
      "aa3dc6d6d4d44b2b8d25ab7fbb4bf224",
      "ad68d7d2d6ef4d90aef6d7536d61bd5b",
      "9fe1ce51c081490bbbfc9c3688bc3860"
     ]
    },
    "executionInfo": {
     "elapsed": 22033,
     "status": "ok",
     "timestamp": 1719390623304,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "M95Y207T7wSp",
    "outputId": "5ed465a8-9fa9-4c8e-8b50-bdc3630bbef1"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1b4de2592d454a7ba7560ea6849dc6ba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/560 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b226cae1a8fc4b8384e0f5fc67d94d89",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/4.40G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "da07fb8f4a204a4aa9125ddb89526992",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4f88583ce59b467cb9791d09f1d57afd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ab5f2a459764205ae4263b22f64a7aa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a641c51801964d749c11be9ecdaf8749",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5f59b68201d94411b978417aa902e747",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
    "\n",
    "model_name = \"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\"\n",
    "\n",
    "# 4-bit quantization configuration - Q in QLoRA\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,  # Use 4-bit precision model loading\n",
    "    bnb_4bit_quant_type=\"nf4\",  # Quantization type\n",
    "    bnb_4bit_compute_dtype=\"float16\",  # Compute dtype\n",
    "    bnb_4bit_use_double_quant=True,  # Apply nested quantization\n",
    ")\n",
    "\n",
    "# Load the model to train on the GPU\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    device_map=\"auto\",\n",
    "\n",
    "    # Leave this out for regular (non-quantized) SFT\n",
    "    quantization_config=bnb_config,\n",
    ")\n",
    "model.config.use_cache = False\n",
    "model.config.pretraining_tp = 1\n",
    "\n",
    "# Load the Llama tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=False)\n",
    "tokenizer.pad_token = \"<PAD>\"\n",
    "tokenizer.padding_side = \"left\""
   ]
  },
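  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sanity check on why 4-bit loading matters, the cell below estimates how much memory the weights of a model with about 1.1 billion parameters occupy at different precisions. The parameter count is an approximation taken from the model name, and the estimate ignores quantization constants, activations, LoRA weights, optimizer state, and any layers that bitsandbytes keeps in higher precision, so actual GPU memory use will be higher. The float32 figure is consistent with the 4.40G `model.safetensors` download shown above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_params = 1.1e9  # approximate parameter count of TinyLlama-1.1B\n",
    "\n",
    "# Storage cost of a single weight at each precision\n",
    "bytes_per_param = {\n",
    "    'float32': 4.0,\n",
    "    'float16': 2.0,\n",
    "    '4-bit (nf4)': 0.5,\n",
    "}\n",
    "\n",
    "for dtype, num_bytes in bytes_per_param.items():\n",
    "    gigabytes = num_params * num_bytes / 1e9\n",
    "    print(f'{dtype:>12}: ~{gigabytes:.2f} GB of weights')"
   ]
  },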
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "t1iGIch-sAMC"
   },
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "86o1T5n4DziD"
   },
   "source": [
    "### LoRA Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0tYs1ZhYDyw9"
   },
   "outputs": [],
   "source": [
    "from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model\n",
    "\n",
    "# Prepare LoRA Configuration\n",
    "peft_config = LoraConfig(\n",
    "    lora_alpha=32,  # LoRA Scaling\n",
    "    lora_dropout=0.1,  # Dropout for LoRA Layers\n",
    "    r=64,  # Rank\n",
    "    bias=\"none\",\n",
    "    task_type=\"CAUSAL_LM\",\n",
    "    target_modules=[  # Layers to target\n",
    "        'k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj'\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Prepare the quantized model for training and attach the LoRA adapters\n",
    "model = prepare_model_for_kbit_training(model)\n",
    "model = get_peft_model(model, peft_config)"
   ]
  },
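  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make `r=64` more concrete, the cell below counts the parameters LoRA adds to the seven targeted projections. An adapted linear layer that maps `d_in` inputs to `d_out` outputs gains two low-rank matrices, `A` with shape `r × d_in` and `B` with shape `d_out × r`, which adds `r * (d_in + d_out)` trainable parameters, and its update is scaled by `lora_alpha / r`. The layer sizes are assumptions based on the published TinyLlama-1.1B configuration (hidden size 2048, intermediate size 5632, 22 decoder layers, 4 key/value heads of dimension 64), so check them against `model.config`. The total can be compared with the output of `model.print_trainable_parameters()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumed TinyLlama-1.1B dimensions (verify against model.config)\n",
    "hidden_size = 2048\n",
    "intermediate_size = 5632\n",
    "num_layers = 22\n",
    "kv_dim = 4 * 64  # num_key_value_heads * head_dim (grouped-query attention)\n",
    "\n",
    "r = 64\n",
    "lora_alpha = 32\n",
    "\n",
    "# (d_in, d_out) for each targeted projection in one decoder layer\n",
    "projections = {\n",
    "    'q_proj': (hidden_size, hidden_size),\n",
    "    'k_proj': (hidden_size, kv_dim),\n",
    "    'v_proj': (hidden_size, kv_dim),\n",
    "    'o_proj': (hidden_size, hidden_size),\n",
    "    'gate_proj': (hidden_size, intermediate_size),\n",
    "    'up_proj': (hidden_size, intermediate_size),\n",
    "    'down_proj': (intermediate_size, hidden_size),\n",
    "}\n",
    "\n",
    "# LoRA adds A (r x d_in) and B (d_out x r) to every targeted layer\n",
    "per_layer = sum(r * (d_in + d_out) for d_in, d_out in projections.values())\n",
    "total = per_layer * num_layers\n",
    "\n",
    "print(f'LoRA parameters per layer: {per_layer:,}')\n",
    "print(f'LoRA parameters in total:  {total:,}')\n",
    "print(f'Scaling factor lora_alpha / r = {lora_alpha / r}')"
   ]
  },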
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Zhbh7kKuD24o"
   },
   "source": [
    "### Training Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TwxZkx80G6bO"
   },
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    "\n",
    "output_dir = \"./results\"\n",
    "\n",
    "# Training arguments\n",
    "training_arguments = TrainingArguments(\n",
    "    output_dir=output_dir,\n",
    "    per_device_train_batch_size=2,\n",
    "    gradient_accumulation_steps=4,\n",
    "    optim=\"paged_adamw_32bit\",\n",
    "    learning_rate=2e-4,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    num_train_epochs=1,\n",
    "    logging_steps=10,\n",
    "    fp16=True,\n",
    "    gradient_checkpointing=True\n",
    ")"
   ]
  },
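  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These settings interact. With `per_device_train_batch_size=2` and `gradient_accumulation_steps=4`, gradients from 2 × 4 = 8 examples are accumulated before each optimizer step. The cell below works out the number of optimizer steps for the 3,000 training examples and evaluates a cosine decay from `learning_rate=2e-4`. It assumes a single GPU, no warmup (the `TrainingArguments` default), and no sequence packing, so treat it as an approximation of what the `Trainer` does rather than its exact code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "num_examples = 3_000\n",
    "per_device_train_batch_size = 2\n",
    "gradient_accumulation_steps = 4\n",
    "num_gpus = 1  # assumption: a single Colab GPU\n",
    "learning_rate = 2e-4\n",
    "\n",
    "effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus\n",
    "# Count batches per epoch, then one optimizer step per group of accumulated batches\n",
    "batches_per_epoch = math.ceil(num_examples / (per_device_train_batch_size * num_gpus))\n",
    "total_steps = batches_per_epoch // gradient_accumulation_steps\n",
    "\n",
    "def cosine_lr(step, total_steps, peak_lr):\n",
    "    \"\"\"Cosine decay from peak_lr down to 0, without warmup.\"\"\"\n",
    "    progress = step / total_steps\n",
    "    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))\n",
    "\n",
    "print(f'Effective batch size: {effective_batch_size}')\n",
    "print(f'Optimizer steps in one epoch: {total_steps}')\n",
    "for step in (0, total_steps // 2, total_steps):\n",
    "    print(f'step {step:>3}: lr = {cosine_lr(step, total_steps, learning_rate):.2e}')"
   ]
  },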
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RtwIo5a0D6f1"
   },
   "source": [
    "## Training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000,
     "referenced_widgets": [
      "1b741e85fcf1458f9875c12a9640dfee",
      "d034f23fd4df4e3296477d8dd76be5b1",
      "bee0b0ce2fb84c5eb67a04ced69752d1",
      "4d39d09e1e2648b1b5295f192e9ad356",
      "3f733eb54fe54d879c97dba7a5204ddd",
      "e9dab506c3b242d7b6228394ada6084b",
      "e7d4893b696c4941bf29d349eb2ceabb",
      "6d7ee17aa7024c8088c374781348f9f0",
      "7d8478e66e394f0fb077853e5319ee6a",
      "630e317f036f41f4a9852f7df81eef83",
      "eb35394940c74f60abe2daaeb243fa88"
     ]
    },
    "executionInfo": {
     "elapsed": 774977,
     "status": "ok",
     "timestamp": 1719391399990,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "B2D7RVihsE7Z",
    "outputId": "1a9f8125-6d39-410e-ff94-9a9ac493ff25"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': dataset_text_field, max_seq_length. Will not be supported from version '1.0.0'.\n",
      "\n",
      "Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.\n",
      "  warnings.warn(message, FutureWarning)\n",
      "/usr/local/lib/python3.10/dist-packages/transformers/training_args.py:1965: FutureWarning: `--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--hub_token` instead.\n",
      "  warnings.warn(\n",
      "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:269: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n",
      "  warnings.warn(\n",
      "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:307: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1b741e85fcf1458f9875c12a9640dfee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/3000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:397: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.\n",
      "  warnings.warn(\n",
      "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='375' max='375' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [375/375 12:45, Epoch 1/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>1.670600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>1.475400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>1.451400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>1.487800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>1.477900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>1.390500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>1.495200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>1.450300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90</td>\n",
       "      <td>1.427900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>1.404400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>110</td>\n",
       "      <td>1.414400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>120</td>\n",
       "      <td>1.377500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>130</td>\n",
       "      <td>1.332100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>140</td>\n",
       "      <td>1.497000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>150</td>\n",
       "      <td>1.347000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>160</td>\n",
       "      <td>1.411500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>170</td>\n",
       "      <td>1.454000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>180</td>\n",
       "      <td>1.324500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>190</td>\n",
       "      <td>1.419300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>1.474900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>210</td>\n",
       "      <td>1.404600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>220</td>\n",
       "      <td>1.342100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>230</td>\n",
       "      <td>1.361100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>240</td>\n",
       "      <td>1.387300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>250</td>\n",
       "      <td>1.353700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>260</td>\n",
       "      <td>1.345800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>270</td>\n",
       "      <td>1.465400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>280</td>\n",
       "      <td>1.434000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>290</td>\n",
       "      <td>1.387600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>1.376200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>310</td>\n",
       "      <td>1.395000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>320</td>\n",
       "      <td>1.437900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>330</td>\n",
       "      <td>1.387200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>340</td>\n",
       "      <td>1.388100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>350</td>\n",
       "      <td>1.313600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>360</td>\n",
       "      <td>1.444300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>370</td>\n",
       "      <td>1.452000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from trl import SFTTrainer\n",
    "\n",
    "# Create the supervised fine-tuning (SFT) trainer\n",
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    train_dataset=dataset,\n",
    "    dataset_text_field=\"text\",\n",
    "    tokenizer=tokenizer,\n",
    "    args=training_arguments,\n",
    "    max_seq_length=512,\n",
    "\n",
    "    # Leave this out for regular (full) SFT without LoRA adapters\n",
    "    peft_config=peft_config,\n",
    ")\n",
    "\n",
    "# Train model\n",
    "trainer.train()\n",
    "\n",
    "# Save only the trained QLoRA adapter weights (not the full model)\n",
    "trainer.model.save_pretrained(\"TinyLlama-1.1B-qlora\")"
   ]
  },
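  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the progress bar above, the number of optimizer steps follows directly from the training arguments. This assumes a single GPU and the 3,000 training examples reported by the `Map` progress bar:\n",
    "\n",
    "$$\n",
    "\\text{effective batch size} = 2 \\times 4 = 8, \\qquad \\text{steps per epoch} = \\frac{3000}{8} = 375\n",
    "$$\n",
    "\n",
    "This matches the 375 steps logged for the single training epoch. On multiple GPUs the effective batch size is also multiplied by the number of devices, so an epoch would take fewer steps."
   ]
  },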
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tsIBfv1PsId-"
   },
   "source": [
    "### Merge Adapter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "M6cPdde4Z-ks"
   },
   "outputs": [],
   "source": [
    "from peft import AutoPeftModelForCausalLM\n",
    "\n",
    "# Load the base model together with the saved QLoRA adapter\n",
    "model = AutoPeftModelForCausalLM.from_pretrained(\n",
    "    \"TinyLlama-1.1B-qlora\",\n",
    "    low_cpu_mem_usage=True,\n",
    "    device_map=\"auto\",\n",
    ")\n",
    "\n",
    "# Merge the adapter into the base model and remove the PEFT wrapper\n",
    "merged_model = model.merge_and_unload()"
   ]
  },
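  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Merging folds the learned low-rank update back into the frozen weights. The result is a regular `transformers` model that no longer needs PEFT at inference time. In the standard LoRA formulation, each targeted weight matrix $W_0$ is replaced by\n",
    "\n",
    "$$\n",
    "W = W_0 + \\frac{\\alpha}{r} B A\n",
    "$$\n",
    "\n",
    "Here $B A$ is the trained low-rank product, $r$ is the rank (`r`), and $\\alpha$ is the scaling factor (`lora_alpha`). The update is added to $W_0$ once, so the merged model has the same size and inference cost as the original base model."
   ]
  },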
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jPRYGimIsM2-"
   },
   "source": [
    "### Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 6781,
     "status": "ok",
     "timestamp": 1719391410095,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "15dJC3ZrdVnK",
    "outputId": "3095ed46-5bb8-4288-b3a0-05d7daeaefc4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|user|>\n",
      "Tell me something about Large Language Models.</s>\n",
      "<|assistant|>\n",
      "Large Language Models (LLMs) are a type of artificial intelligence (AI) that can generate human-like language. They are trained on large amounts of data, including text, audio, and video, and are capable of generating complex and nuanced language.\n",
      "\n",
      "LLMs are used in a variety of applications, including natural language processing (NLP), machine translation, and chatbots. They can be used to generate text, speech, or images, and can be trained to understand different languages and dialects.\n",
      "\n",
      "One of the most significant applications of LLMs is in the field of natural language generation (NLG). LLMs can be used to generate text in a variety of languages, including English, French, and German. They can also be used to generate speech, such as in chatbots or voice assistants.\n",
      "\n",
      "LLMs have the potential to revolutionize the way we communicate and interact with each other. They can help us create more engaging and personalized content, and they can also help us understand each other better.\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# Use our predefined prompt template\n",
    "prompt = \"\"\"<|user|>\n",
    "Tell me something about Large Language Models.</s>\n",
    "<|assistant|>\n",
    "\"\"\"\n",
    "\n",
    "# Run our instruction-tuned model\n",
    "pipe = pipeline(task=\"text-generation\", model=merged_model, tokenizer=tokenizer)\n",
    "print(pipe(prompt)[0][\"generated_text\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9JNfYZe9vCb8"
   },
   "source": [
    "# Preference Tuning (PPO/DPO)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ar2h9kZ9qmEG"
   },
   "source": [
    "## Data Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 246,
     "referenced_widgets": [
      "d330d84ac98a4d14b51ffab13277e501",
      "2ed3f903f58e42b4a10c8937e7e8cdf5",
      "415333d658284e5f9566065f9bfc4808",
      "76038e4dce21442caa069273e8c22e42",
      "914814f0575948c688991533f85e59dc",
      "e78efc4e689745d69c6d91ac71d51f39",
      "f113f1d6bf9244c98c89f5610e371431",
      "67066b2143a74269bb15a97ace2f0ae2",
      "45b29b925a224dbebf61f67bf359f393",
      "f6a86cfcc1d347e5b182a9abf57b15b4",
      "9dedc5f35e9241a6b9aff057c1c6ef9b",
      "73e8b10b8ee349d1b91dff30f885e335",
      "79da0fb8fad247c1a5e6a5b8bde4d498",
      "c6ff9de93d92425481c29123ac68bf76",
      "c800d5aacd2f4844856c505f11e09e56",
      "0c5019fa5fc74a81a6b11afc49fe135c",
      "c9c3dad975c4436499c3ac3df06bad39",
      "78fb4c59298b4b10b3dd3d5707ae0d81",
      "a7ed257939254f11b1494bf7ae6d42f9",
      "7a856d774904424699aec1f2e9479016",
      "149367fc059e47f990dde648af05d17c",
      "3245f7bd4eb244b587eb15e3cd0b1d80",
      "063c3f5d0ed64ef4901efb5a4fd64149",
      "d571423cdf204acc8ce3e7a4da3e2526",
      "06220ed92658447fb48ade092c4bb36d",
      "6d05c08c79694bed8345d28bdfe19032",
      "54724f48c14b43f5ba4459459670334a",
      "f5d383934bbb4f758336f08b567f0824",
      "64d85e89ed7d46a981d2a00c47bdf8b2",
      "35a0f7cdae6f4a61acb0dffe5e698130",
      "850abb487dce4b1aa09f8094bc447a9b",
      "1aeb4f6ec0234ed3a6be1ac75244ae8e",
      "93156ea1f4a14de392c28e0b489b4290",
      "a29338f36ee34153a880c3f5fb985616",
      "50de9fd7ffd544559ebbd078ef12345f",
      "4036016f7aad43449ad70cd76f40c5eb",
      "9bfee60d3ab5416e983c42ae6ecdd0e0",
      "ff17fc6260e547268b518231cae50451",
      "80e75187ab95457795532aa6c7b00d76",
      "261b9119ab994265aac43d6b80ffc90d",
      "92ef80ab57e44151a0e9419639cb9a34",
      "bb0813f49acd457cadc27d2384e9274f",
      "bd41907a546e40059723a07c53f39339",
      "d30f87bdd81242e989c97a14ec2f98f3",
      "a98b46f3843b4724b6fad82fac16e219",
      "c1f7130a16d844ed91aa1c97c36f18a1",
      "4715e099e4454e9f9c54e0572d0508d4",
      "e9d5215839c44dbcb782855769ef3d0b",
      "c96ff8419847405889ef3b88a4684739",
      "fe77dd4438fb4de389bbf33ec23bf8b4",
      "3e34500a435a457c9cb321b4b258f76e",
      "30cfc3189acb4db7954e0b6efc9a5645",
      "530d30bd4d524415bec77c5d1725a4ac",
      "3804bdf7cf5f4d9c87ce42c47fadfe1a",
      "27c3bc5002ba49ef81d6c20b40d5719f"
     ]
    },
    "executionInfo": {
     "elapsed": 4958,
     "status": "ok",
     "timestamp": 1719391415052,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "UlbPVO_aac33",
    "outputId": "a2c446d6-e410-4d17-eb98-96b21264e0e9"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d330d84ac98a4d14b51ffab13277e501",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading readme:   0%|          | 0.00/10.1k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "73e8b10b8ee349d1b91dff30f885e335",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/79.2M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "063c3f5d0ed64ef4901efb5a4fd64149",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/12859 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a29338f36ee34153a880c3f5fb985616",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Filter:   0%|          | 0/12859 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a98b46f3843b4724b6fad82fac16e219",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/5922 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['chosen', 'rejected', 'prompt'],\n",
       "    num_rows: 5922\n",
       "})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "def format_prompt(example):\n",
    "    \"\"\"Format the prompt using the <|user|> template that TinyLlama uses\"\"\"\n",
    "\n",
    "    # Wrap the system message, prompt, and answers in the chat template\n",
    "    system = \"<|system|>\\n\" + example['system'] + \"</s>\\n\"\n",
    "    prompt = \"<|user|>\\n\" + example['input'] + \"</s>\\n<|assistant|>\\n\"\n",
    "    chosen = example['chosen'] + \"</s>\\n\"\n",
    "    rejected = example['rejected'] + \"</s>\\n\"\n",
    "\n",
    "    return {\n",
    "        \"prompt\": system + prompt,\n",
    "        \"chosen\": chosen,\n",
    "        \"rejected\": rejected,\n",
    "    }\n",
    "\n",
    "# Load the preference pairs, then keep only non-tied pairs whose chosen answer\n",
    "# scored at least 8 and that do not appear in the GSM8K training set\n",
    "dpo_dataset = load_dataset(\"argilla/distilabel-intel-orca-dpo-pairs\", split=\"train\")\n",
    "dpo_dataset = dpo_dataset.filter(\n",
    "    lambda r:\n",
    "        r[\"status\"] != \"tie\" and\n",
    "        r[\"chosen_score\"] >= 8 and\n",
    "        not r[\"in_gsm8k_train\"]\n",
    ")\n",
    "\n",
    "# Apply the chat template and keep only the prompt, chosen, and rejected columns\n",
    "dpo_dataset = dpo_dataset.map(format_prompt, remove_columns=dpo_dataset.column_names)\n",
    "dpo_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AkCJ4CO5sQG6"
   },
   "source": [
    "## Models - Quantization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5934,
     "status": "ok",
     "timestamp": 1719391420979,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "7YMmilm7c1-P",
    "outputId": "fbf5e75b-cf63-4ac6-b1b5-514cddceb842"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/bnb.py:325: UserWarning: Merge lora module to 4-bit linear may get different generations due to rounding errors.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from peft import AutoPeftModelForCausalLM\n",
    "from transformers import BitsAndBytesConfig, AutoTokenizer\n",
    "\n",
    "# 4-bit quantization configuration - Q in QLoRA\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,  # Use 4-bit precision model loading\n",
    "    bnb_4bit_quant_type=\"nf4\",  # Quantization type\n",
    "    bnb_4bit_compute_dtype=\"float16\",  # Compute dtype\n",
    "    bnb_4bit_use_double_quant=True,  # Apply nested quantization\n",
    ")\n",
    "\n",
    "# Load the SFT adapter on top of the 4-bit quantized base model and merge them\n",
    "model = AutoPeftModelForCausalLM.from_pretrained(\n",
    "    \"TinyLlama-1.1B-qlora\",\n",
    "    low_cpu_mem_usage=True,\n",
    "    device_map=\"auto\",\n",
    "    quantization_config=bnb_config,\n",
    ")\n",
    "merged_model = model.merge_and_unload()\n",
    "\n",
    "# Load LLaMA tokenizer\n",
    "model_name = \"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=False)\n",
    "tokenizer.pad_token = \"<PAD>\"\n",
    "tokenizer.padding_side = \"left\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iidCbaXMs1O4"
   },
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "m6IfkvLkylVD"
   },
   "outputs": [],
   "source": [
    "from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model\n",
    "\n",
    "# Prepare LoRA Configuration\n",
    "peft_config = LoraConfig(\n",
    "    lora_alpha=32,  # LoRA Scaling\n",
    "    lora_dropout=0.1,  # Dropout for LoRA Layers\n",
    "    r=64,  # Rank\n",
    "    bias=\"none\",\n",
    "    task_type=\"CAUSAL_LM\",\n",
    "    target_modules=  # Layers to target\n",
    "     ['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj']\n",
    ")\n",
    "\n",
    "# Prepare the merged 4-bit model for training and attach new LoRA adapters\n",
    "model = prepare_model_for_kbit_training(merged_model)\n",
    "model = get_peft_model(model, peft_config)"
   ]
  },
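  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for these settings: with `r=64` and `lora_alpha=32`, the LoRA update is scaled by $\\alpha / r = 32 / 64 = 0.5$. A rank-$r$ adapter on a $d_{out} \\times d_{in}$ weight trains $r(d_{in} + d_{out})$ parameters instead of $d_{in} d_{out}$.\n",
    "\n",
    "As an illustration, assume a hidden size of 2048 for TinyLlama-1.1B (you can verify this in the model's `config.hidden_size`). A square projection such as `q_proj` then gives\n",
    "\n",
    "$$\n",
    "64 \\times (2048 + 2048) = 262{,}144 \\quad \\text{vs.} \\quad 2048 \\times 2048 = 4{,}194{,}304\n",
    "$$\n",
    "\n",
    "In other words, the adapter trains about 6.25% as many parameters as that layer contains."
   ]
  },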
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lk-cEEd8nk27"
   },
   "outputs": [],
   "source": [
    "from trl import DPOConfig\n",
    "\n",
    "output_dir = \"./results\"\n",
    "\n",
    "# Training arguments\n",
    "training_arguments = DPOConfig(\n",
    "    output_dir=output_dir,\n",
    "    per_device_train_batch_size=2,\n",
    "    gradient_accumulation_steps=4,\n",
    "    optim=\"paged_adamw_32bit\",\n",
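    "    learning_rate=1e-5,  # Lower than the 2e-4 used for SFT\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    max_steps=200,  # Fixed number of steps; overrides num_train_epochs\n",
    "    logging_steps=10,\n",
    "    fp16=True,\n",
    "    gradient_checkpointing=True,\n",
    "    warmup_ratio=0.1  # Warm up the learning rate over the first 10% of steps\n",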
    ")"
   ]
  },
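  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `max_steps` is set, DPO trains for a fixed budget rather than a full epoch. Assuming a single GPU, the number of preference pairs the model sees is\n",
    "\n",
    "$$\n",
    "200 \\times (2 \\times 4) = 1{,}600 \\text{ pairs}, \\qquad \\frac{1600}{5922} \\approx 0.27 \\text{ epochs}\n",
    "$$\n",
    "\n",
    "This is consistent with the trainer below reporting `Epoch 0/1` after 200 steps."
   ]
  },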
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000,
     "referenced_widgets": [
      "c98fc033e551443b94dc8dae31d590bf",
      "b20056d36ac942cc9f8c3a9efc020f36",
      "3f8a810d4e2549c9a4ba2f3f2ee017e8",
      "47768d322ac94e1494d7f4a2f01440b7",
      "656d76b25562479ba3b24f22800a675b",
      "eab02ca086f44759a9b1b00d8f1a1245",
      "d960c3e6a7a04fb8b525dec294da6815",
      "983f9c1d7e47494383099916eed69c0d",
      "0236bebb8dce4b5dade5728304ffb964",
      "9e1a61b6c5f8482eadd024280da208f3",
      "4380e0fb571e41c9a58b09a06a20b853"
     ]
    },
    "executionInfo": {
     "elapsed": 805129,
     "status": "ok",
     "timestamp": 1719392226734,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "Pp3tUXhWm0pE",
    "outputId": "29378dc8-bb8a-435b-8330-e45aa26548c7"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': max_prompt_length, max_length. Will not be supported from version '1.0.0'.\n",
      "\n",
      "Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.\n",
      "  warnings.warn(message, FutureWarning)\n",
      "/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/bnb.py:325: UserWarning: Merge lora module to 4-bit linear may get different generations due to rounding errors.\n",
      "  warnings.warn(\n",
      "/usr/local/lib/python3.10/dist-packages/trl/trainer/dpo_trainer.py:358: UserWarning: You passed `max_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`.\n",
      "  warnings.warn(\n",
      "/usr/local/lib/python3.10/dist-packages/trl/trainer/dpo_trainer.py:371: UserWarning: You passed `max_prompt_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`.\n",
      "  warnings.warn(\n",
      "/usr/local/lib/python3.10/dist-packages/trl/trainer/dpo_trainer.py:411: UserWarning: When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments we have set it for you, but you should do it yourself in the future.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c98fc033e551443b94dc8dae31d590bf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/5922 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "max_steps is given, it will override any value given in num_train_epochs\n",
      "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:91: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
      "  warnings.warn(\n",
      "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='200' max='200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [200/200 12:52, Epoch 0/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.692400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>0.678200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>0.646000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>0.606300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>0.595600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>0.616800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>0.593700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>0.531900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90</td>\n",
       "      <td>0.559200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.639000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>110</td>\n",
       "      <td>0.496500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>120</td>\n",
       "      <td>0.586000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>130</td>\n",
       "      <td>0.630000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>140</td>\n",
       "      <td>0.590100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>150</td>\n",
       "      <td>0.577500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>160</td>\n",
       "      <td>0.591000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>170</td>\n",
       "      <td>0.606900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>180</td>\n",
       "      <td>0.627800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>190</td>\n",
       "      <td>0.668600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.555400</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from trl import DPOTrainer\n",
    "\n",
    "# Create DPO trainer\n",
    "dpo_trainer = DPOTrainer(\n",
    "    model,\n",
    "    args=training_arguments,\n",
    "    train_dataset=dpo_dataset,\n",
    "    tokenizer=tokenizer,\n",
    "    peft_config=peft_config,\n",
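    "    beta=0.1,  # Higher values keep the model closer to the reference model\n",
    "    max_prompt_length=512,  # Maximum number of prompt tokens\n",
    "    max_length=512,  # Maximum number of tokens for prompt plus answer\n",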
    ")\n",
    "\n",
    "# Fine-tune model with DPO\n",
    "dpo_trainer.train()\n",
    "\n",
    "# Save adapter\n",
    "dpo_trainer.model.save_pretrained(\"TinyLlama-1.1B-dpo-qlora\")"
   ]
  },
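  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The logged DPO loss is easier to interpret with the objective in mind. For a prompt $x$ with chosen answer $y_w$ and rejected answer $y_l$, DPO minimizes\n",
    "\n",
    "$$\n",
    "\\mathcal{L}_{\\text{DPO}} = -\\log \\sigma\\left(\\beta \\left[\\log \\frac{\\pi_\\theta(y_w \\mid x)}{\\pi_{\\text{ref}}(y_w \\mid x)} - \\log \\frac{\\pi_\\theta(y_l \\mid x)}{\\pi_{\\text{ref}}(y_l \\mid x)}\\right]\\right)\n",
    "$$\n",
    "\n",
    "Here $\\beta$ is the `beta=0.1` passed to the trainer. No separate reference model is passed, so, as we understand TRL's PEFT handling, the model with the new adapter disabled serves as $\\pi_{\\text{ref}}$. LoRA also initializes its update to zero. At the first step both log-ratios are therefore zero, and the loss is\n",
    "\n",
    "$$\n",
    "-\\log \\sigma(0) = \\ln 2 \\approx 0.693\n",
    "$$\n",
    "\n",
    "This agrees with the first logged value of 0.6924. Values below $\\ln 2$ roughly indicate that the model now prefers the chosen answers more strongly than the reference model does."
   ]
  },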
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QFE4OKFvyLMe"
   },
   "outputs": [],
   "source": [
    "from peft import AutoPeftModelForCausalLM, PeftModel\n",
    "\n",
    "# Load the SFT adapter and merge it into the base model\n",
    "model = AutoPeftModelForCausalLM.from_pretrained(\n",
    "    \"TinyLlama-1.1B-qlora\",\n",
    "    low_cpu_mem_usage=True,\n",
    "    device_map=\"auto\",\n",
    ")\n",
    "sft_model = model.merge_and_unload()\n",
    "\n",
    "# Load the DPO adapter on top of the merged SFT model and merge it as well\n",
    "dpo_model = PeftModel.from_pretrained(\n",
    "    sft_model,\n",
    "    \"TinyLlama-1.1B-dpo-qlora\",\n",
    "    device_map=\"auto\",\n",
    ")\n",
    "dpo_model = dpo_model.merge_and_unload()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 6777,
     "status": "ok",
     "timestamp": 1719392237608,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "zAkwJcHYmxr4",
    "outputId": "631aed7c-1e64-4e2c-db73-3e36ddee4e75"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|user|>\n",
      "Tell me something about Large Language Models.</s>\n",
      "<|assistant|>\n",
      "Large Language Models (LLMs) are a type of artificial intelligence (AI) that can generate human-like language. They are trained on large amounts of data, including text, audio, and video, and are capable of generating complex and nuanced language.\n",
      "\n",
      "LLMs are used in a variety of applications, including natural language processing (NLP), machine translation, and chatbots. They can be used to generate text, speech, or images, and can be trained to understand different languages and dialects.\n",
      "\n",
      "One of the most significant applications of LLMs is in the field of natural language generation (NLG). LLMs can be used to generate text in a variety of languages, including English, French, and German. They can also be used to generate speech, such as in chatbots or voice assistants.\n",
      "\n",
      "LLMs have the potential to revolutionize the way we communicate and interact with each other. They can help us create more engaging and personalized content, and they can also help us understand each other better.\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# Use our predefined prompt template\n",
    "prompt = \"\"\"<|user|>\n",
    "Tell me something about Large Language Models.</s>\n",
    "<|assistant|>\n",
    "\"\"\"\n",
    "\n",
    "# Run our preference-tuned (DPO) model\n",
    "pipe = pipeline(task=\"text-generation\", model=dpo_model, tokenizer=tokenizer)\n",
    "print(pipe(prompt)[0][\"generated_text\"])"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyPaxPKtmt1gCzzuqYr6g2+g",
   "gpuType": "T4",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
